{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tFhqrPopctwa"
      },
      "source": [
        "Copyright 2020 DeepMind Technologies Limited.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zuYtT7GpuVgQ"
      },
      "source": [
        "\u003cp align=\"center\"\u003e\n",
        "  \u003ch1 align=\"center\"\u003eTAPIR: Tracking Any Point with per-frame Initialization and temporal Refinement\u003c/h1\u003e\n",
        "  \u003cp align=\"center\"\u003e\n",
        "    \u003ca href=\"http://www.carldoersch.com/\"\u003eCarl Doersch\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://yangyi02.github.io/\"\u003eYi Yang\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://scholar.google.com/citations?user=Jvi_XPAAAAAJ\"\u003eMel Vecerik\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://scholar.google.com/citations?user=cnbENAEAAAAJ\"\u003eDilara Gokay\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://www.robots.ox.ac.uk/~ankush/\"\u003eAnkush Gupta\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"http://people.csail.mit.edu/yusuf/\"\u003eYusuf Aytar\u003c/a\u003e\n",
        "    \u003ca href=\"https://scholar.google.co.uk/citations?user=IUZ-7_cAAAAJ\"\u003eJoao Carreira\u003c/a\u003e\n",
        "    ·\n",
        "    \u003ca href=\"https://www.robots.ox.ac.uk/~az/\"\u003eAndrew Zisserman\u003c/a\u003e\n",
        "  \u003c/p\u003e\n",
        "  \u003ch3 align=\"center\"\u003e\u003ca href=\"https://arxiv.org/abs/2306.08637\"\u003ePaper\u003c/a\u003e | \u003ca href=\"https://deepmind-tapir.github.io\"\u003eProject Page\u003c/a\u003e | \u003ca href=\"https://github.com/deepmind/tapnet\"\u003eGitHub\u003c/a\u003e | \u003ca href=\"https://github.com/deepmind/tapnet/tree/main#running-tapir-locally\"\u003eLive Demo\u003c/a\u003e \u003c/h3\u003e\n",
        "  \u003cdiv align=\"center\"\u003e\u003c/div\u003e\n",
        "\u003c/p\u003e\n",
        "\n",
        "\u003cp align=\"center\"\u003e\n",
        "  \u003ca href=\"\"\u003e\n",
        "    \u003cimg src=\"https://storage.googleapis.com/dm-tapnet/swaying_gif.gif\" alt=\"Logo\" width=\"50%\"\u003e\n",
        "  \u003c/a\u003e\n",
        "\u003c/p\u003e\n",
        "\u003ch3\u003e\n",
        "\u003c/h3\u003e"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mCmDvfFvxnGB"
      },
      "outputs": [],
      "source": [
        "# @title Install code and dependencies {form-width: \"25%\"}\n",
        "!pip install git+https://github.com/google-deepmind/tapnet.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HaswJZMq9B3c"
      },
      "outputs": [],
      "source": [
        "# @title Download Model {form-width: \"25%\"}\n",
        "\n",
        "%mkdir tapnet/checkpoints\n",
        "\n",
        "!wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/causal_bootstapir_checkpoint.pt\n",
        "\n",
        "%ls tapnet/checkpoints"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FxlHY242m-6Q"
      },
      "outputs": [],
      "source": [
        "# @title Imports {form-width: \"25%\"}\n",
        "%matplotlib widget\n",
        "import haiku as hk\n",
        "import jax\n",
        "import matplotlib.pyplot as plt\n",
        "import mediapy as media\n",
        "import numpy as np\n",
        "import tree\n",
        "\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "from tapnet.torch import tapir_model\n",
        "from tapnet.utils import transforms\n",
        "from tapnet.utils import viz_utils\n",
        "\n",
        "from google.colab import output\n",
        "output.enable_custom_widget_manager()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "swWFHeHFrsT0"
      },
      "outputs": [],
      "source": [
        "# @title Select Device {form-width: \"25%\"}\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "  device = torch.device('cuda')\n",
        "else:\n",
        "  device = torch.device('cpu')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7rfy2yobnHqw"
      },
      "outputs": [],
      "source": [
        "# @title Load an Exemplar Video {form-width: \"25%\"}\n",
        "\n",
        "%mkdir tapnet/examplar_videos\n",
        "\n",
        "!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4\n",
        "\n",
        "video = media.read_video('tapnet/examplar_videos/horsejump-high.mp4')\n",
        "height, width = video.shape[1:3]\n",
        "media.show_video(video, fps=10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hhaB3rBssLLc"
      },
      "outputs": [],
      "source": [
        "# @title Utility Functions {form-width: \"25%\"}\n",
        "\n",
        "def preprocess_frames(frames):\n",
        "  \"\"\"Preprocess frames to model inputs.\n",
        "\n",
        "  Args:\n",
        "    frames: [num_frames, height, width, 3], [0, 255], np.uint8\n",
        "\n",
        "  Returns:\n",
        "    frames: [num_frames, height, width, 3], [-1, 1], np.float32\n",
        "  \"\"\"\n",
        "  frames = frames.float()\n",
        "  frames = frames / 255 * 2 - 1\n",
        "  return frames\n",
        "\n",
        "\n",
        "def sample_random_points(frame_max_idx, height, width, num_points):\n",
        "  \"\"\"Sample random points with (time, height, width) order.\"\"\"\n",
        "  y = np.random.randint(0, height, (num_points, 1))\n",
        "  x = np.random.randint(0, width, (num_points, 1))\n",
        "  t = np.random.randint(0, frame_max_idx + 1, (num_points, 1))\n",
        "  points = np.concatenate((t, y, x), axis=-1).astype(np.int32)  # [num_points, 3]\n",
        "  return points\n",
        "\n",
        "\n",
        "def postprocess_occlusions(occlusions, expected_dist):\n",
        "  visibles = (1 - F.sigmoid(occlusions)) * (1 - F.sigmoid(expected_dist)) \u003e 0.5\n",
        "  return visibles"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I7wOMJoSQzq1"
      },
      "outputs": [],
      "source": [
        "# @title Define Model {form-width: \"25%\"}\n",
        "\n",
        "model = tapir_model.TAPIR(pyramid_level=1, use_casual_conv=True)\n",
        "\n",
        "model.load_state_dict(torch.load('tapnet/checkpoints/causal_bootstapir_checkpoint.pt'))\n",
        "model = model.to(device)\n",
        "model = model.eval()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Izp33JBg6eij"
      },
      "outputs": [],
      "source": [
        "def online_model_init(frames, query_points):\n",
        "  \"\"\"Initialize query features for the query points.\"\"\"\n",
        "  frames = preprocess_frames(frames)\n",
        "  feature_grids = model.get_feature_grids(frames, is_training=False)\n",
        "  query_features = model.get_query_features(\n",
        "      frames,\n",
        "      is_training=False,\n",
        "      query_points=query_points,\n",
        "      feature_grids=feature_grids,\n",
        "  )\n",
        "  return query_features\n",
        "\n",
        "def online_model_predict(frames, query_features, causal_context):\n",
        "  \"\"\"Compute point tracks and occlusions given frames and query points.\"\"\"\n",
        "  frames = preprocess_frames(frames)\n",
        "  feature_grids = model.get_feature_grids(frames, is_training=False)\n",
        "  trajectories = model.estimate_trajectories(\n",
        "      frames.shape[-3:-1],\n",
        "      is_training=False,\n",
        "      feature_grids=feature_grids,\n",
        "      query_features=query_features,\n",
        "      query_points_in_video=None,\n",
        "      query_chunk_size=64,\n",
        "      causal_context=causal_context,\n",
        "      get_causal_context=True,\n",
        "  )\n",
        "  causal_context = trajectories['causal_context']\n",
        "  del trajectories['causal_context']\n",
        "  # Take only the predictions for the final resolution.\n",
        "  # For running on higher resolution, it's typically better to average across\n",
        "  # resolutions.\n",
        "  tracks = trajectories['tracks'][-1]\n",
        "  occlusions = trajectories['occlusion'][-1]\n",
        "  uncertainty = trajectories['expected_dist'][-1]\n",
        "  visibles = postprocess_occlusions(occlusions, uncertainty)\n",
        "  return tracks, visibles, causal_context"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_LLK7myqp3Px"
      },
      "outputs": [],
      "source": [
        "# @title Predict Sparse Point Tracks {form-width: \"25%\"}\n",
        "\n",
        "resize_height = 256  # @param {type: \"integer\"}\n",
        "resize_width = 256  # @param {type: \"integer\"}\n",
        "num_points = 50  # @param {type: \"integer\"}\n",
        "\n",
        "frames = media.resize_video(video, (resize_height, resize_width))\n",
        "query_points = sample_random_points(0, frames.shape[1], frames.shape[2], num_points)\n",
        "frames = torch.tensor(frames).to(device)\n",
        "query_points = torch.tensor(query_points).to(device)\n",
        "\n",
        "query_features = online_model_init(frames[None, 0:1],\n",
        "                                   query_points[None])\n",
        "\n",
        "causal_state = model.construct_initial_causal_state(query_points.shape[0],\n",
        "                                                    len(query_features.resolutions) - 1)\n",
        "\n",
        "with torch.no_grad():\n",
        "  for i in range(len(causal_state)):\n",
        "    for k,v in causal_state[i].items():\n",
        "      causal_state[i][k] = v.to(device)\n",
        "\n",
        "  # Predict point tracks frame by frame\n",
        "  predictions = []\n",
        "  for i in range(frames.shape[0]):\n",
        "    # Note: we add a batch dimension.\n",
        "    tracks, visibles, causal_state = online_model_predict(\n",
        "        frames=frames[None, i:i+1],\n",
        "        query_features=query_features,\n",
        "        causal_context=causal_state,\n",
        "    )\n",
        "    predictions.append({'tracks':tracks, 'visibles':visibles})\n",
        "\n",
        "tracks = torch.cat([x['tracks'][0] for x in predictions], dim=1)\n",
        "visibles = torch.cat([x['visibles'][0] for x in predictions], dim=1)\n",
        "\n",
        "tracks = tracks.cpu().numpy()\n",
        "visibles = visibles.cpu().numpy()\n",
        "# Visualize sparse point tracks\n",
        "tracks = transforms.convert_grid_coordinates(tracks, (resize_width, resize_height), (width, height))\n",
        "video_viz = viz_utils.paint_point_track(video, tracks, visibles)\n",
        "media.show_video(video_viz, fps=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8orfQRoaRJit"
      },
      "source": [
        "That's it!"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "https://github.com/deepmind/tapnet/blob/master/colabs/torch_causal_tapir_demo.ipynb",
          "timestamp": 1720431514861
        }
      ]
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
